      
import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from PIL import Image
from torch.utils.data import Dataset
from transformers import Qwen2VLForConditionalGeneration

# from math_verify import parse, verify
# from trainer import Qwen2VLGRPOTrainer, GRPOConfig
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
from transformers import TrainingArguments
import yaml
import json
import random
import math

# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, \
    apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
import torch
from typing import Tuple
from torchvision.transforms import ToTensor, ToPILImage
import numpy as np


import spacy
import string
from sentence_transformers import SentenceTransformer, util

# spaCy English pipeline, loaded once when this module is imported.
# correctness_score_reward uses it to tokenize predicted and reference answers.
nlp = spacy.load("en_core_web_lg")

def debug_point():
    """Open an ipdb session on rank 0, then make every rank wait at a barrier.

    Ranks other than 0 go straight to the barrier, so they block until the
    debugging session on rank 0 is released.
    """
    from torch import distributed

    on_rank_zero = distributed.get_rank() == 0
    if on_rank_zero:
        from ipdb import set_trace
        set_trace()
    distributed.barrier()


class BidirectionalEmbeddingMatching:
    """Text-similarity scorer built on a SentenceTransformer model.

    NOTE(review): the class and method names describe token-level
    bidirectional matching, but ``max_pooling_similarity`` reads the
    ``"sentence_embedding"`` entry of the model output. If that entry is a
    single pooled vector per sentence, the max/mean reductions are no-ops and
    the score is a plain sentence-level cosine similarity. Confirm whether
    ``"token_embeddings"`` was intended before changing it: downstream
    thresholds were presumably tuned on the current behaviour.
    """

    def __init__(self,
                 model_name="/mnt/bn/pistis/liutao.0220/MODEL/all-mpnet-base-v2"):
        """
        Initialize the BEM similarity model and move it to the GPU.
        :param model_name: name or local path of the pretrained model handed to SentenceTransformer
        """
        self.model = SentenceTransformer(model_name).cuda()

    def get_token_embeddings(self, sentence):
        """
        Tokenize one sentence and run it through the model.
        :param sentence: input text
        :return: the raw result of calling the model; max_pooling_similarity
                 indexes it with the "sentence_embedding" key
        """
        tokens = self.model.tokenize([sentence])
        # Move every tokenizer output onto the same device as the model.
        for k, v in tokens.items():
            tokens[k] = v.to(self.model.device)
        # NOTE(review): output_value / convert_to_tensor look like encode() options;
        # confirm the installed sentence-transformers forward() accepts them.
        outputs = self.model(tokens, output_value="token_embeddings", convert_to_tensor=True)
        return outputs

    def max_pooling_similarity(self, emb1, emb2):
        """
        Compute the bidirectional max-pooled similarity between two model outputs.
        :param emb1: model output for sentence 1 (indexed by "sentence_embedding")
        :param emb2: model output for sentence 2 (indexed by "sentence_embedding")
        :return: BEM similarity score as a Python float
        """
        # Cosine similarity matrix between the two embeddings.
        cosine_sim_matrix = util.pytorch_cos_sim(emb1["sentence_embedding"], emb2["sentence_embedding"])

        # Forward matching (sentence1 -> sentence2): best match per row, then averaged.
        forward_match = torch.max(cosine_sim_matrix, dim=1)[0]
        forward_score = torch.mean(forward_match)

        # Backward matching (sentence2 -> sentence1): best match per column, then averaged.
        backward_match = torch.max(cosine_sim_matrix, dim=0)[0]
        backward_score = torch.mean(backward_match)

        # Final score is the average of both directions.
        bem_similarity = (forward_score + backward_score) / 2
        return bem_similarity.item()

    def compute_similarity(self, sentence1, sentence2):
        """
        Compute the BEM similarity of two sentences.
        :param sentence1: sentence 1
        :param sentence2: sentence 2
        :return: similarity score as a Python float (cosine-based)
        """
        # Only a float leaves this method, so skip autograd bookkeeping for the
        # encoder forward passes instead of building graphs that are thrown away.
        with torch.no_grad():
            emb1 = self.get_token_embeddings(sentence1)
            emb2 = self.get_token_embeddings(sentence2)
            return self.max_pooling_similarity(emb1, emb2)


# Module-level scorer shared by the similarity-based reward functions below.
# Constructing it at import time loads the model and moves it to CUDA.
bem = BidirectionalEmbeddingMatching()


def yes_or_no_check(pred, gt):
    """Return True when ``pred`` and ``gt`` agree on a yes/no answer.

    Both texts are lower-cased and split into words; the check succeeds when
    the word "yes" appears in both, or the word "no" appears in both.
    Matching is on whole words, so "know", "north", "not" or "eyes" do not
    count as a yes/no answer.

    :param pred: predicted answer text
    :param gt: ground-truth answer text
    :return: True if both contain "yes" or both contain "no", else False
    """
    pred_words = set(re.findall(r"\w+", pred.lower()))
    gt_words = set(re.findall(r"\w+", gt.lower()))
    return any(word in pred_words and word in gt_words for word in ("yes", "no"))


def multiple_choice_check(pred, gt, ques):
    """Return True when ``pred`` names the same option letter as ``gt``.

    The check only applies to questions whose text contains "choices:"
    (case-insensitive); any other question yields False.

    The ground-truth letter is an upper-case A-K standing on its own after
    punctuation is removed. A letter surrounded by whitespace is preferred
    (the original rule); otherwise a letter at the start or end of the text is
    accepted, so stripped answers such as "A", "(A)" or "A. cat" are
    recognised too. The prediction must consist of just that letter once
    punctuation and surrounding whitespace are removed (case-insensitive).

    :param pred: predicted answer text
    :param gt: ground-truth answer text
    :param ques: question text
    :return: True if the prediction is the ground-truth option letter
    """
    if 'choices:' not in ques.lower():
        return False

    def strip_punctuation(text):
        # Drop every character that is neither a word character nor whitespace.
        return re.sub(r'[^\w\s]', '', text)

    def extract_choice(text):
        text = strip_punctuation(text)
        # Prefer a letter with whitespace on both sides; fall back to one that
        # touches the start or end of the text.
        match = re.search(r'\s([ABCDEFGHIJK])\s', text) or \
            re.search(r'(?:^|\s)([ABCDEFGHIJK])(?:\s|$)', text)
        return match.group(1) if match else None

    gt_choice = extract_choice(gt)
    pred_choice = strip_punctuation(pred).strip().upper()

    return gt_choice is not None and pred_choice == gt_choice


def correctness_reward(completions, solution, **kwargs):
    """Binary correctness reward based on the ``<answer>...</answer>`` span.

    A completion earns 1.0 when its extracted answer
      * equals the solution ignoring case, periods and surrounding whitespace, or
      * agrees with the solution on a yes/no answer (``yes_or_no_check``), or
      * names the same option letter (``multiple_choice_check``);
    otherwise, including when no answer tag is present, it earns 0.0.

    The multiple-choice check receives the question *text*. A prompt that is
    already a string is used as is; otherwise the text is read from
    ``prompt[0]['content'][1]['text']``, the same location the debug log uses.

    :param completions: list of completions, each a list whose first item has a "content" string
    :param solution: list of ground-truth answer strings
    :param kwargs: must contain ``prompts``, aligned with ``completions``
    :return: list of float rewards, one per completion

    When the environment variable ``DEBUG_MODE`` is "true" and ``LOG_PATH`` is
    set, each sample is appended to that log file.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    answer_tag_pattern = r'<answer>([\s\S]*?)</answer>'
    # None unless debugging is on and a log path is configured.
    log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "true" else None

    def normalize(text):
        # Ignore periods, surrounding whitespace and letter case.
        return text.replace('.', '').strip().lower()

    for content, sol, pro in zip(contents, solution, kwargs['prompts']):
        reward = 0.0
        try:
            content_answer_match = re.search(answer_tag_pattern, content)
            if content_answer_match:
                content_answer = content_answer_match.group(1).strip()
                gt = sol.strip()

                if content_answer.lower() == gt.lower() or normalize(content_answer) == normalize(gt):
                    reward = 1.0
                elif yes_or_no_check(content_answer, gt):
                    reward = 1.0
                else:
                    # Pass the question text, not the whole prompt object.
                    question = pro if isinstance(pro, str) else pro[0]['content'][1]['text']
                    if multiple_choice_check(content_answer, gt, question):
                        reward = 1.0
        except Exception as e:
            # Best effort: a malformed sample keeps its current reward, but the
            # failure is reported instead of being silently discarded.
            print(f"correctness_reward: answer check failed: {e}")

        rewards.append(reward)
        if log_path:
            with open(log_path, "a") as f:
                f.write(f"------------- {current_time} correctness_reward: {reward} -------------\n")
                f.write(f"Prompt: {pro[0]['content'][1]['text']}\n====================\n")
                f.write(f"Content: {content}\n====================\n")
                f.write(f"Solution: {sol}\n====================\n")
    return rewards

def correctness_score_reward(completions, solution, **kwargs):
    """Graded correctness reward: share of solution words found in the answer.

    The text inside ``<answer>...</answer>`` and the solution are stripped of
    punctuation and tokenized with the module-level spaCy pipeline ``nlp``.
    The reward is ``|pred words & solution words| / |solution words|`` over
    lower-cased alphabetic or digit tokens (0.0 when the solution has none).
    It is raised to 1.0 when ``multiple_choice_check`` or ``yes_or_no_check``
    accepts the answer. Completions without an answer tag earn 0.0.

    The multiple-choice check receives the question *text*. A prompt that is
    already a string is used as is; otherwise the text is read from
    ``prompt[0]['content'][1]['text']``, the same location the debug log uses.

    :param completions: list of completions, each a list whose first item has a "content" string
    :param solution: list of ground-truth answer strings
    :param kwargs: must contain ``prompts``, aligned with ``completions``
    :return: list of rewards, one per completion

    When the environment variable ``DEBUG_MODE`` is "true" and ``LOG_PATH`` is
    set, each sample is appended to that log file.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    answer_tag_pattern = r'<answer>([\s\S]*?)</answer>'
    # None unless debugging is on and a log path is configured.
    log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "true" else None
    punctuation_table = str.maketrans('', '', string.punctuation)

    def word_set(text):
        # Lower-cased alphabetic / digit tokens of the punctuation-free text.
        doc = nlp(text.translate(punctuation_table))
        return {token.text.lower() for token in doc if token.is_alpha or token.is_digit}

    for content, sol, pro in zip(contents, solution, kwargs['prompts']):
        reward = 0.0
        try:
            content_answer_match = re.search(answer_tag_pattern, content)
            if content_answer_match:
                content_answer = content_answer_match.group(1).strip()

                pred_tokens = word_set(content_answer)
                gt_tokens = word_set(sol)

                # Fraction of solution words that the answer reproduces.
                reward = len(pred_tokens & gt_tokens) / len(gt_tokens) if gt_tokens else 0.0

                gt = sol.strip()
                if yes_or_no_check(content_answer, gt):
                    reward = 1.0
                else:
                    # Pass the question text, not the whole prompt object.
                    question = pro if isinstance(pro, str) else pro[0]['content'][1]['text']
                    if multiple_choice_check(content_answer, gt, question):
                        reward = 1.0
        except Exception as e:
            # Best effort: a malformed sample keeps its current reward, but the
            # failure is reported instead of being silently discarded.
            print(f"correctness_score_reward: answer check failed: {e}")

        rewards.append(reward)
        if log_path:
            with open(log_path, "a") as f:
                f.write(f"------------- {current_time} correctness_score_reward: {reward} -------------\n")
                f.write(f"Prompt: {pro[0]['content'][1]['text']}\n====================\n")
                f.write(f"Content: {content}\n====================\n")
                f.write(f"Solution: {sol}\n====================\n")
    return rewards

def format_reward(completions, **kwargs):
    """Score 1.0 for completions shaped exactly as ``<think>...</think><answer>...</answer>``.

    The whole completion text must match: a think block, optional whitespace,
    then an answer block, with nothing before or after. Anything else scores 0.0.
    """
    layout = re.compile(r"<think>([\s\S]*?)</think>\s*<answer>([\s\S]*?)</answer>", re.DOTALL)
    texts = [completion[0]["content"] for completion in completions]
    scores = []
    for text in texts:
        if layout.fullmatch(text):
            scores.append(1.0)
        else:
            scores.append(0.0)
    return scores


def format_plain_reward(completions, **kwargs):
    """Score 1.0 for completions that are exactly one ``<answer>...</answer>`` block.

    The whole completion text must match, with nothing before the opening tag
    or after the closing tag. Anything else scores 0.0.
    """
    layout = re.compile(r"<answer>([\s\S]*?)</answer>", re.DOTALL)
    texts = [completion[0]["content"] for completion in completions]
    scores = []
    for text in texts:
        if layout.fullmatch(text):
            scores.append(1.0)
        else:
            scores.append(0.0)
    return scores



def correctness_bem_score_reward(completions, solution, **kwargs):
    """Correctness reward driven by the embedding similarity from ``bem``.

    For each completion the text inside ``<answer>...</answer>`` is compared
    with the solution via ``bem.compute_similarity``. A similarity of at least
    0.6 becomes the reward; lower values leave it at 0.0. The reward is raised
    to 1.0 when ``multiple_choice_check`` or ``yes_or_no_check`` accepts the
    answer. Completions without an answer tag earn 0.0. Exceptions raised
    while scoring are printed and the reward reached so far is kept.

    With the environment variable ``DEBUG_MODE`` set to "true", every sample
    is appended to the file named by ``LOG_PATH``.
    """
    texts = [completion[0]["content"] for completion in completions]
    stamp = datetime.now().strftime("%d-%H-%M-%S-%f")
    answer_re = re.compile(r'<answer>([\s\S]*?)</answer>', re.DOTALL)
    scores = []

    for text, sol, pro in zip(texts, solution, kwargs['prompts']):
        score = 0.0
        try:
            found = answer_re.search(text)
            if found is not None:
                answer = found.group(1).strip()

                sim = bem.compute_similarity(answer, sol)
                # Similarities below the threshold do not count at all.
                if not sim < 0.6:
                    score = sim

                reference = sol.strip()
                question = pro[0]['content'][1]['text']
                if multiple_choice_check(answer, reference, question):
                    score = 1.0

                if yes_or_no_check(answer, reference):
                    score = 1.0
        except Exception as err:
            print(err)

        scores.append(score)

        if os.getenv("DEBUG_MODE") != "true":
            continue
        with open(os.getenv("LOG_PATH"), "a") as log:
            log.write(f"------------- {stamp} correctness_bem_score_reward: {score} -------------\n")
            log.write(f"Prompt: {pro[0]['content'][1]['text']}\n====================\n")
            log.write(f"Content: {text}\n====================\n")
            log.write(f"Solution: {sol}\n====================\n")
    return scores



def noisy_reward(completions, solution, **kwargs):
    """Reward each completion with ``1 - noise_extent``.

    The reward depends only on the noise extent paired with each sample; the
    completion text, solution and prompt are used for debug logging only.

    :param completions: list of completions, each a list whose first item has a "content" string
    :param solution: list of ground-truth answers (logged only)
    :param kwargs: must contain ``prompts`` and ``noise_extents``, aligned with ``completions``
    :return: list with one ``1 - noise_extent`` value per completion

    When the environment variable ``DEBUG_MODE`` is "true" and ``LOG_PATH`` is
    set, each sample is appended to that log file.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    # None unless debugging is on and a log path is configured.
    log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "true" else None

    for content, sol, pro, n_r in zip(contents, solution, kwargs['prompts'], kwargs['noise_extents']):
        reward = 1 - n_r

        rewards.append(reward)
        if log_path:
            with open(log_path, "a") as f:
                f.write(f"------------- {current_time} Noisy reward: {reward} -------------\n")
                f.write(f"Prompt: {pro[0]['content'][1]['text']}\n====================\n")
                f.write(f"Content: {content}\n====================\n")
                f.write(f"Solution: {sol}\n====================\n")

    return rewards

def noisy_cond_reward(completions, solution, **kwargs):
    """Noise-aware reward conditioned on the most recent clean completion.

    Each completion is paired with the latest preceding completion (possibly
    itself) whose noise extent is 0.0; the batch must therefore start with a
    clean sample, otherwise the assertion fails. With ``sim`` the ``bem``
    similarity between that clean text and the completion, the reward is
    ``(2 - noise_extent * (1 - sim)) / 2``.

    With the environment variable ``DEBUG_MODE`` set to "true", every sample
    is appended to the file named by ``LOG_PATH``.
    """
    texts = [completion[0]["content"] for completion in completions]
    stamp = datetime.now().strftime("%d-%H-%M-%S-%f")

    # Walk the batch once, remembering the latest clean (noise-free) text.
    references = []
    latest_clean = None
    for extent, text in zip(kwargs['noise_extents'], texts):
        latest_clean = text if extent == 0.0 else latest_clean
        assert latest_clean is not None
        references.append(latest_clean)

    scores = []
    samples = zip(texts, solution, kwargs['prompts'], kwargs['noise_extents'], references)
    for text, sol, pro, extent, reference in samples:
        sim = bem.compute_similarity(reference, text)

        # Noise is discounted by how close the output stays to the clean one.
        effective_noise = extent * (1 - sim)
        score = (2 - effective_noise) / 2
        scores.append(score)

        if os.getenv("DEBUG_MODE") != "true":
            continue
        with open(os.getenv("LOG_PATH"), "a") as log:
            log.write(f"------------- {stamp} Noisy reward: {score} -------------\n")
            log.write(f"Prompt: {pro[0]['content'][1]['text']}\n====================\n")
            log.write(f"Content: {text}\n====================\n")
            log.write(f"clean Content: {reference}\n====================\n")
            log.write(f"SIM: {sim}\n====================\n")
            log.write(f"Solution: {sol}\n====================\n")

    return scores


def noisy_verify_reward(completions, solution, **kwargs):
    """Similarity-weighted average of ``1 - noise_extent`` over the batch.

    For each completion, every completion in the batch (itself included)
    contributes its base score ``1 - noise_extent`` with weight
    ``exp(5 * sim)``, where ``sim`` is the ``bem`` similarity between the two
    texts. The reward is the weighted mean of those base scores.

    With the environment variable ``DEBUG_MODE`` set to "true", every sample
    is appended to the file named by ``LOG_PATH``.
    """
    texts = [completion[0]["content"] for completion in completions]
    stamp = datetime.now().strftime("%d-%H-%M-%S-%f")
    sharpness = 5
    results = []

    for text, sol, pro, _extent in zip(texts, solution, kwargs['prompts'], kwargs['noise_extents']):
        peer_weights = []
        peer_scores = []
        for peer_extent, peer_text in zip(kwargs['noise_extents'], texts):
            peer_score = 1 - peer_extent
            sim = bem.compute_similarity(text, peer_text)
            peer_weights.append(math.exp(sharpness * sim))
            peer_scores.append(peer_score)

        weighted_total = sum(s * w for s, w in zip(peer_scores, peer_weights))
        outcome = weighted_total / sum(peer_weights)
        results.append(outcome)

        if os.getenv("DEBUG_MODE") != "true":
            continue
        with open(os.getenv("LOG_PATH"), "a") as log:
            log.write(f"------------- {stamp} Noisy reward: {outcome} -------------\n")
            log.write(f"Prompt: {pro[0]['content'][1]['text']}\n====================\n")
            log.write(f"Content: {text}\n====================\n")
            log.write(f"Weights: {peer_weights}\n====================\n")
            log.write(f"Score: {peer_scores}\n====================\n")
            log.write(f"Solution: {sol}\n====================\n")

    return results

def correctness_thinking_reward(completions, solution, **kwargs):
    """Contrastive reward on the reasoning text of each completion.

    A sample is "correct" when the first two columns of
    ``kwargs['rewards_per_func']`` sum to exactly 2. The reasoning text is the
    first ``<think>...</think>`` span (tags included), or the whole completion
    when no such span exists.

    * No correct sample: every reward is 0.
    * All samples correct: every reward is 1.
    * Otherwise: each sample gets ``1 - mean(similarity to the samples of the
      opposite class)``, with similarities from ``bem.compute_similarity``.

    Only cross-class pairs influence the result, so only those are computed,
    and nothing is computed in the two constant cases. Each pair is scored
    once and reused for both samples, which relies on the similarity being
    symmetric.

    :param completions: list of completions, each a list whose first item has a "content" string
    :param solution: list of ground-truth answers (only its length is used)
    :param kwargs: must contain ``prompts`` and ``rewards_per_func``; the latter is
                   indexed as a 2-D tensor with one row per completion
    :return: 1-D float32 tensor of rewards on the device of ``rewards_per_func``
    """
    contents = [completion[0]["content"] for completion in completions]
    think_tag_pattern = r'<think>([\s\S]*?)</think>'

    rewards_per_func = kwargs['rewards_per_func'][:, :2].sum(dim=1)
    correct_index = rewards_per_func == 2

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Number of samples that line up across all per-sample inputs.
    n = min(len(contents), len(solution), len(kwargs['prompts']), len(correct_index))

    correct_flags = correct_index.cpu().numpy().astype(bool)[:n]
    pos_indices = np.flatnonzero(correct_flags)
    neg_indices = np.flatnonzero(~correct_flags)

    if len(pos_indices) == 0:
        values = np.zeros(n)
    elif len(neg_indices) == 0:
        values = np.ones(n)
    else:
        think_list = []
        for text in contents[:n]:
            match = re.search(think_tag_pattern, text) if isinstance(text, str) else None
            # group(0) keeps the surrounding <think> tags, as the embedding input.
            think_list.append(match.group(0).strip() if match else text)

        # cross_sim[a, b]: similarity between correct sample a and incorrect sample b.
        with torch.no_grad():
            cross_sim = np.array([
                [
                    bem.compute_similarity(think_list[min(p, q)], think_list[max(p, q)])
                    for q in neg_indices
                ]
                for p in pos_indices
            ])

        values = np.empty(n)
        # Correct samples are rewarded for differing from the incorrect ones ...
        values[pos_indices] = 1 - cross_sim.mean(axis=1)
        # ... and incorrect samples for differing from the correct ones.
        values[neg_indices] = 1 - cross_sim.mean(axis=0)

    return torch.tensor(values, dtype=torch.float32, device=correct_index.device)

# Maps each reward name to the function that implements it.
# NOTE(review): correctness_thinking_reward is defined in this module but not
# registered here; confirm whether that is intentional.
reward_funcs_registry = {
    "correctness": correctness_reward,
    "correctness_score": correctness_score_reward,
    "format": format_reward,
    "format_plain": format_plain_reward,
    "correctness_bem_score": correctness_bem_score_reward,
    "noisy": noisy_reward,
    "noisy_cond": noisy_cond_reward,
    "noisy_verify": noisy_verify_reward,
}

    